#  This file is part of the NssMPClib project.
#  Copyright (c) 2024 XDU NSS lab,
#  Licensed under the MIT license. See LICENSE in the project root for license information.

"""
Multiplication (Beaver) triples for arithmetic secret sharing (ASS) and replicated secret sharing (RSS).
"""
import os
import random
import re
from NssMPC.crypto.aux_parameter import Parameter
from NssMPC.common.ring import RingTensor
from NssMPC.common.utils import cuda_matmul
from NssMPC.config import param_path
from NssMPC.crypto.primitives.homomorphic_encryption.paillier import Paillier


class AssMulTriples(Parameter):
    """
    Parameter class for ASS (arithmetic secret sharing) multiplication triples.

    An instance holds one party's shares ``a``, ``b`` and ``c`` of the triples and offers helpers to generate
    the triples and save them to disk.
    """

    def __init__(self):
        """
        Initializes the a, b, and c of the multiplication triple to ``None`` and sets the number of triples
        (``size``) to 0.
        """
        self.a = None
        self.b = None
        self.c = None
        self.size = 0

    def __iter__(self):
        """
        Iterator support.

        Allows an AssMulTriples instance to be unpacked or iterated over, yielding ``a``, ``b`` and ``c`` in
        that order.

        :return: An iterator over ``(a, b, c)``.
        :rtype: iterator
        """
        return iter((self.a, self.b, self.c))

    def set_party(self, party):
        """
        Set the ``party`` attribute of the shares a, b and c.

        The shares must already be set; calling this while they are still ``None`` raises ``AttributeError``.

        :param party: the party that owns the shares
        :type party: Party
        """
        self.a.party = party
        self.b.party = party
        self.c.party = party

    def set_triples(self, a, b, c):
        """
        Set attributes a, b, c of the class.

        :param a: share of the first factor
        :param b: share of the second factor
        :param c: share of the product
        """
        self.a = a
        self.b = b
        self.c = c

    @staticmethod
    def gen(num_of_triples, num_of_party=2, type_of_generation='TTP', party=None):
        """
        Generate a specified number of multiplication triples according to the chosen generation method.

        Dispatches on the value of ``type_of_generation``:
            * :func:`gen_triples_by_homomorphic_encryption` : Generate triples based on homomorphic encryption.
            * :func:`gen_triples_by_ttp` : Generates triples based on trusted third parties.

        :param num_of_triples: number of triples
        :type num_of_triples: int
        :param num_of_party: number of parties (only used for 'TTP')
        :type num_of_party: int
        :param type_of_generation: generation type. (TTP: generated by trusted third party, HE: generated by homomorphic encryption)
        :type type_of_generation: str
        :param party: party if HE
        :type party: Party
        :returns: for 'TTP' a list with one AssMulTriples per party; for 'HE' the AssMulTriples of ``party``
        :raises ValueError: if ``type_of_generation`` is neither 'HE' nor 'TTP'
        """
        if type_of_generation == 'HE':
            return gen_triples_by_homomorphic_encryption(num_of_triples, party)
        if type_of_generation == 'TTP':
            return gen_triples_by_ttp(num_of_triples, num_of_party, share_type='normal')
        # Fail fast instead of silently returning None, which previously made gen_and_save write nothing.
        raise ValueError(f"Unsupported type_of_generation: {type_of_generation!r} (expected 'HE' or 'TTP')")

    @classmethod
    def gen_and_save(cls, num, num_of_party=2, saved_name=None, saved_path=None, type_of_generation='TTP', party=None):
        """
        Generates multiplication triples and saves them to the specified path.

        The triples are generated through :meth:`gen`, which rejects an invalid ``type_of_generation``.
        If ``saved_path`` is not specified, the default path ``{param_path}AssMulTriples/`` is used. If the
        directory does not exist, it is created.
            * For triples generated by 'TTP', one file ``{saved_name}_{i}.pth`` is saved per party ``i``.
            * For triples generated by 'HE', the file is named ``AssMulTriples_{party_id}_{k}.pth`` where ``k``
              is one more than the largest index already present for this party in the directory.

        :param num: number of triples
        :type num: int
        :param num_of_party: number of parties (only used for 'TTP')
        :type num_of_party: int
        :param saved_name: Optional file name prefix to save under (only used for 'TTP').
        :type saved_name: str
        :param saved_path: saved path for the generated params
        :type saved_path: str
        :param type_of_generation: generation type. (TTP: generated by trusted third party, HE: generated by homomorphic encryption)
        :type type_of_generation: str
        :param party: party if HE
        :type party: Party
        :raises ValueError: if ``type_of_generation`` is neither 'HE' nor 'TTP'
        """
        triples = cls.gen(num, num_of_party, type_of_generation, party)

        file_path = f"{param_path}AssMulTriples/" if saved_path is None else saved_path
        os.makedirs(file_path, exist_ok=True)

        if saved_name is None:
            saved_name = 'AssMulTriples'

        if type_of_generation == 'TTP':
            for i in range(num_of_party):
                file_name = f"{saved_name}_{i}.pth"
                triples[i].save(file_path=file_path, name=file_name)

        # TODO: the HE branch has not been adapted yet (it ignores ``saved_name``).
        elif type_of_generation == 'HE':
            # The pattern must mirror the file name written below. The previous pattern lacked the
            # underscore after the prefix, never matched, and so every run overwrote index 1.
            pattern = re.compile(rf"AssMulTriples_{party.party_id}_(\d+)\.pth")
            max_ptr = 0
            for fname in os.listdir(file_path):
                match = pattern.fullmatch(fname)
                if match:
                    max_ptr = max(max_ptr, int(match.group(1)))
            file_name = f"AssMulTriples_{party.party_id}_{max_ptr + 1}.pth"
            triples.save(file_path=file_path, name=file_name)


class RssMulTriples(Parameter):
    """
    Parameter class for RSS (replicated secret sharing) multiplication triples.

    An instance holds one party's shares ``a``, ``b`` and ``c`` and offers helpers to generate the triples
    and save them to disk.
    """

    def __init__(self):
        """
        Start with empty shares (``a``, ``b``, ``c`` are ``None``) and a triple count (``size``) of 0.
        """
        self.a = self.b = self.c = None
        self.size = 0

    def __iter__(self):
        """
        Iterator support.

        Lets an RssMulTriples instance be unpacked or looped over as ``a``, ``b``, ``c``.

        :return: An iterator over ``(a, b, c)``.
        :rtype: iterator
        """
        members = (self.a, self.b, self.c)
        return iter(members)

    def set_party(self, party):
        """
        Attach ``party`` to each of the shares a, b and c.

        :param party: the party that owns the shares
        :type party: Party
        """
        for share in (self.a, self.b, self.c):
            share.party = party

    def set_triples(self, a, b, c):
        """
        Store the given shares as attributes a, b and c.
        """
        self.a, self.b, self.c = a, b, c

    @staticmethod
    def gen(num_of_triples):
        """
        Generate multiplicative Beaver triples for three parties.

        Delegates to :func:`gen_triples_by_ttp` with three parties and the 'replicate' share type.

        :param num_of_triples: the number of triples
        :type num_of_triples: int
        :returns: whatever :func:`gen_triples_by_ttp` returns (a list with one entry per party)
        """
        return gen_triples_by_ttp(num_of_triples, 3, 'replicate')

    @classmethod
    def gen_and_save(cls, num, type_of_generation='TTP', party=None):
        """
        Generate multiplicative Beaver triples and, for the *TTP* generation type, save them to disk.

        The triples are always generated through :meth:`gen`. Only when ``type_of_generation`` is 'TTP' are
        they written: the directory ``{param_path}RssMulTriples/`` is created if missing and the entries
        0, 1 and 2 of the generated list are saved as ``RssMulTriples_{i}.pth``. For any other value nothing
        is saved.

        :param num: number of triples
        :type num: int
        :param type_of_generation: generation type; only 'TTP' leads to files being written
        :type type_of_generation: str
        :param party: not used by this method
        :type party: Party
        """
        generated = cls.gen(num)
        if type_of_generation != 'TTP':
            return

        target_dir = f"{param_path}RssMulTriples/"
        if not os.path.exists(target_dir):
            os.makedirs(target_dir)

        for party_idx in range(3):
            generated[party_idx].save(target_dir, f"RssMulTriples_{party_idx}.pth")


def gen_triples_by_ttp(num_of_triples, num_of_party, share_type='normal'):
    """
    Generate the multiplication Beaver triple by trusted third party.

    Two random RingTensors ``a`` and ``b`` of length ``num_of_triples`` are drawn and ``c = a * b`` is
    computed. The three values are then split according to ``share_type``:
        * ``'normal'``: ``ArithmeticSecretSharing.share`` is used and the results are packed into AssMulTriples objects.
        * ``'replicate'``: ``ReplicatedSecretSharing.share`` is used and the results are packed into RssMulTriples objects.

    For every party index in ``range(num_of_party)`` one triple object is created holding that party's
    shares of a, b and c (moved to the CPU) together with ``size = num_of_triples``.

    :param num_of_triples: number of triples
    :type num_of_triples: int
    :param num_of_party: number of parties
    :type num_of_party: int
    :param share_type: share type, normal or replicate
    :type share_type: str
    :returns: a list with one triple object per party
    :raises TypeError: if ``share_type`` is neither 'normal' nor 'replicate'
    """
    a = RingTensor.random([num_of_triples])
    b = RingTensor.random([num_of_triples])
    c = a * b

    # Pick the sharing scheme and the container class; the plaintext values are shared in the order a, b, c.
    if share_type == 'normal':
        from NssMPC.crypto.primitives.arithmetic_secret_sharing import ArithmeticSecretSharing
        a_shares, b_shares, c_shares = [ArithmeticSecretSharing.share(value, num_of_party) for value in (a, b, c)]
        triple_cls = AssMulTriples
    elif share_type == 'replicate':
        from NssMPC.crypto.primitives.arithmetic_secret_sharing import ReplicatedSecretSharing
        a_shares, b_shares, c_shares = [ReplicatedSecretSharing.share(value) for value in (a, b, c)]
        triple_cls = RssMulTriples
    else:
        raise TypeError("Share type not supported")

    per_party = []
    for idx in range(num_of_party):
        triple = triple_cls()
        triple.a = a_shares[idx].to('cpu')
        triple.b = b_shares[idx].to('cpu')
        triple.c = c_shares[idx].to('cpu')
        triple.size = num_of_triples
        per_party.append(triple)
    return per_party


def gen_triples_by_homomorphic_encryption(num_of_triples, party):
    """
    Generate the multiplication Beaver triple by homomorphic encryption.

    Each party first draws two lists of random integers ``a`` and ``b`` from ``[0, 2 ** 32]``.
        If the current participant's party_id is **0**:
            Paillier keys are generated, ``a`` and ``b`` are encrypted and sent to the peer together with the
            public key. The reply ``d`` is decrypted and ``c[i] = d[i] + a[i] * b[i]``.

        If the current participant's party_id is **1**:
            A random mask ``r`` is generated and ``c[i] = a[i] * b[i] - r[i]``. After receiving the message
            from Party 0, ``r`` is encrypted under the received public key, and
            ``d[i] = Enc(a_0[i]) ** b[i] * Enc(b_0[i]) ** a[i] * Enc(r[i])`` is computed and sent back.

        For any other party_id no exchange takes place and ``c`` stays empty.

    Finally ``a``, ``b`` and ``c`` are wrapped as ArithmeticSecretSharing objects on the CPU and stored in an
    AssMulTriples.

    :param num_of_triples: number of triples
    :type num_of_triples: int
    :param party: the party to generate triples; must provide ``party_id``, ``send`` and ``receive``
    :type party: Party
    :returns: An AssMulTriples that includes a, b, and c.
    :rtype: AssMulTriples
    """
    # Exponentiation is ``**``; the previous ``2 ^ 32`` was a bitwise XOR (== 34), so every operand and
    # mask was drawn from the tiny range [0, 34].
    upper_bound = 2 ** 32
    # NOTE(review): ``random`` is not a cryptographically secure generator; consider
    # ``random.SystemRandom``/``secrets`` if these triples are used outside of testing.
    a = [random.randint(0, upper_bound) for _ in range(num_of_triples)]
    b = [random.randint(0, upper_bound) for _ in range(num_of_triples)]
    c = []

    if party.party_id == 0:
        paillier = Paillier()
        paillier.gen_keys()

        encrypted_a = paillier.encrypt(a)
        encrypted_b = paillier.encrypt(b)

        party.send([encrypted_a, encrypted_b, paillier.public_key])

        d = party.receive()
        decrypted_d = paillier.decrypt(d)
        c = [decrypted_d[i] + a[i] * b[i] for i in range(num_of_triples)]

    elif party.party_id == 1:
        r = [random.randint(0, upper_bound) for _ in range(num_of_triples)]
        c = [a[i] * b[i] - r[i] for i in range(num_of_triples)]

        # messages == [Enc(a_0), Enc(b_0), public_key] as sent by party 0 above.
        messages = party.receive()

        encrypted_r = Paillier.encrypt_with_key(r, messages[2])
        d = [messages[0][i] ** b[i] * messages[1][i] ** a[i] * encrypted_r[i] for i in range(num_of_triples)]

        party.send(d)

    triples = AssMulTriples()
    from NssMPC.crypto.primitives.arithmetic_secret_sharing import ArithmeticSecretSharing
    # NOTE(review): entries of ``c`` can be negative or exceed 64 bits; presumably RingTensor reduces them
    # into the ring - confirm against RingTensor's constructor.
    triples.a = ArithmeticSecretSharing(RingTensor(a).to('cpu'), party)
    triples.b = ArithmeticSecretSharing(RingTensor(b).to('cpu'), party)
    triples.c = ArithmeticSecretSharing(RingTensor(c).to('cpu'), party)
    triples.size = num_of_triples

    return triples


class MatmulTriples(AssMulTriples):
    """
    Matrix-multiplication triples; generated by :func:`gen_matrix_triples_by_ttp`.
    """

    @staticmethod
    def gen(num_of_triples, x_shape=None, y_shape=None, num_of_party=2):
        """
        Generate matrix multiplication Beaver triples.

        Thin wrapper that forwards all arguments to :func:`gen_matrix_triples_by_ttp` (trusted third party).

        :param num_of_triples: number of triples
        :type num_of_triples: int
        :param x_shape: the shape of the matrix x
        :type x_shape: tuple
        :param y_shape: the shape of the matrix y
        :type y_shape: tuple
        :param num_of_party: number of parties
        :type num_of_party: int
        :returns: the list returned by :func:`gen_matrix_triples_by_ttp`
        """
        return gen_matrix_triples_by_ttp(
            num_of_param=num_of_triples,
            x_shape=x_shape,
            y_shape=y_shape,
            num_of_party=num_of_party,
        )

    # NOTE(review): ``gen_and_save`` is inherited from AssMulTriples, which calls
    # ``cls.gen(num, num_of_party, type_of_generation, party)``; that argument list does not line up with
    # the signature of ``gen`` above. The disabled draft below is kept for reference.
    #
    # @classmethod
    # def gen_and_save(cls, num_of_triples, x_shape=None, y_shape=None, num_of_party=2):
    #     """
    #     Generate and save multiplicative Beaver triples
    #
    #     Args:
    #         num_of_triples: the number of triples
    #         num_of_party: the number of parties
    #         x_shape: the shape of the matrix x
    #         y_shape: the shape of the matrix y
    #     """
    #     triples = cls.gen(num_of_triples, x_shape, y_shape, num_of_party)
    #     for party_id in range(num_of_party):
    #         file_path = base_path + f"/aux_parameters/BeaverTriples/{num_of_party}party/Matrix"
    #         file_name = f"MatrixBeaverTriples_{party_id}_{list(x_shape)}_{list(y_shape)}.pth"
    #         triples[party_id].save_by_name(file_name, file_path)


def gen_matrix_triples_by_ttp(num_of_param, x_shape, y_shape, num_of_party=2):
    """
    Generate the matrix multiplication Beaver triple by trusted third party.

    ``num_of_param`` is prepended as the leading dimension to ``x_shape`` and ``y_shape``, and two random
    RingTensors ``a`` and ``b`` with those shapes are drawn via ``RingTensor.random``.
        * If ``a.device`` equals ``'cpu'``: ``c`` is computed with the matrix multiplication operator **@**.
        * Otherwise: ``cuda_matmul`` is applied to the underlying tensors and the result is passed through ``RingTensor.convert_to_ring``.

    The sharing scheme is selected by the number of parties:
        * Two parties: ``ArithmeticSecretSharing.share`` is used and the shares are stored in MatmulTriples objects.
        * Three parties: ``ReplicatedSecretSharing.share`` is used and the shares are stored in RssMatmulTriples objects.

    In both cases each party's shares of a, b and c are moved to the CPU before being stored.

    :param num_of_param: number of parameters
    :type num_of_param: int
    :param x_shape: the shape of the matrix x
    :type x_shape: tuple
    :param y_shape: the shape of the matrix y
    :type y_shape: tuple
    :param num_of_party: number of parties
    :type num_of_party: int
    :returns: a list with one triple object per party
    :raises TypeError: if ``num_of_party`` is neither 2 nor 3
    """
    batched_x_shape = [num_of_param, *x_shape]
    batched_y_shape = [num_of_param, *y_shape]
    a = RingTensor.random(batched_x_shape)
    b = RingTensor.random(batched_y_shape)

    if a.device == 'cpu':
        c = a @ b
    else:
        c = RingTensor.convert_to_ring(cuda_matmul(a.tensor, b.tensor))

    # TODO: the share-type selection here needs to be reworked; the original note defers this until after
    # a planned code split.
    if num_of_party == 2:
        from NssMPC.crypto.primitives.arithmetic_secret_sharing import ArithmeticSecretSharing
        a_shares, b_shares, c_shares = [ArithmeticSecretSharing.share(value, num_of_party) for value in (a, b, c)]
        triple_cls = MatmulTriples
    elif num_of_party == 3:
        from NssMPC.crypto.primitives.arithmetic_secret_sharing import ReplicatedSecretSharing
        a_shares, b_shares, c_shares = [ReplicatedSecretSharing.share(value) for value in (a, b, c)]
        triple_cls = RssMatmulTriples
    else:
        raise TypeError("Share type not supported")

    per_party = []
    for idx in range(num_of_party):
        triple = triple_cls()
        triple.a = a_shares[idx].to('cpu')
        triple.b = b_shares[idx].to('cpu')
        triple.c = c_shares[idx].to('cpu')
        per_party.append(triple)
    return per_party


class RssMatmulTriples(RssMulTriples):
    """
    Matrix-multiplication triples for RSS; generated by :func:`gen_matrix_triples_by_ttp`.
    """

    # NOTE(review): the inherited ``gen_and_save`` calls ``cls.gen(num)``, which leaves ``x_shape`` and
    # ``y_shape`` at ``None`` here - confirm whether a matrix-specific ``gen_and_save`` is needed.
    @staticmethod
    def gen(num_of_triples, x_shape=None, y_shape=None, num_of_party=3):
        """
        Generate matrix multiplication Beaver triples for RSS.

        Thin wrapper that forwards all arguments to :func:`gen_matrix_triples_by_ttp` (trusted third party).

        :param num_of_triples: number of parameters
        :type num_of_triples: int
        :param x_shape: the shape of the matrix x
        :type x_shape: tuple
        :param y_shape: the shape of the matrix y
        :type y_shape: tuple
        :param num_of_party: number of parties
        :type num_of_party: int
        :returns: the list returned by :func:`gen_matrix_triples_by_ttp`
        """
        return gen_matrix_triples_by_ttp(
            num_of_param=num_of_triples,
            x_shape=x_shape,
            y_shape=y_shape,
            num_of_party=num_of_party,
        )
